#=============================================================================================

Solve for policy functions using EconPDEs package

Denote by V the value function, V = E ∫ e^{-ρ t} C_t^{1-γ} / (1-γ) dt.
Instead of solving for V(W), we solve for p(w), defined by
V(W) = (p(w) * Y)^(1-γ) / (1-γ), where w = A / Y (A = W denotes assets).
Moreover, relative to the traditional HJB, we write the HJB divided by (1-γ) * V / p = p(w)^(-γ) * Y^(1-γ).

For retired/unemployed households, Y should be understood as permanent income, not counting the income drop at retirement / unemployment (actual income is χR * Y or χU * Y). With Y defined this way, Y never jumps, so w never jumps either, which makes it easier to write the system of PDEs.

For bequests, the final condition is V(W) = b * (r + (ρ - r) / γ)^(-γ) * W^(1-γ) / (1-γ) = b * (r + (ρ - r) / γ)^(-γ) * w^(1-γ) * Y^(1-γ) / (1-γ),
which can be written as a final condition on p: p(w)^(1-γ) = b * (r + (ρ - r) / γ)^(-γ) * w^(1-γ) (the code evaluates it at w + 0.01, so that p > 0 at w = 0).
=============================================================================================#

# Parameters of the life-cycle incomplete-markets model.
# Households are employed (E), unemployed (U) or retired (R). Labor income Y is geometric,
# with drift μ and volatility σ while young (μR and σR once retired). Actual income is Y
# when employed and the fraction χU / χR of Y when unemployed / retired. The state
# variable is w = A / Y.
# `Base.@kwdef` provides a keyword constructor with the defaults below, e.g.
# `IncompleteMarketsModel(ρ = 0.05, γ = 2)`. The struct is mutable so that fields such as
# ρ and b can be overwritten (compute_moments does this on a deepcopy).
Base.@kwdef mutable struct IncompleteMarketsModel
    μ::Float64 = 0.01             # geometric drift of labor income process
    σ::Float64 = 0.1              # geometric volatility of labor income process
 
    μR::Float64 = 0.0             # geometric drift of labor income process when retired
    σR::Float64 = 0.0             # geometric volatility of labor income process when retired
    χR::Float64 = 0.7             # fraction of income received when retired

    # -log(1 - p) converts a per-period transition probability p into a hazard rate
    # (the period is presumably one year, as τs are ages — TODO confirm)
    λEU::Float64 = -log(1 - 0.05) # hazard rate employment -> unemployment (p = 5%)
    λUE::Float64 = -log(1 - 0.8)  # hazard rate unemployment -> employment (p = 80%)
    χU::Float64 = 0.6             # fraction of income when unemployed (i.e. unemployment insurance)
 
    r::Float64 = log(1+0.05)      # interest rate (continuously compounded equivalent of 5%)
    γ::Float64 = 1.5              # household relative risk aversion (RRA)
    ρ::Float64 = 0.048            # household subjective discount rate (SDR)
    b::Float64 = 0.02             # bequest preference (weight in the terminal condition); can go down to 1e-5

    wmin::Float64 = 0.0           # lower limit of the asset grid (borrowing limit)
    wmax::Float64 = 1000.0        # upper limit of the asset grid
end

# Retired households: solved first, backward in time from the terminal (bequest) condition
"""
    solve_model_old(m, state, y)

HJB residual of a retired household at one grid point, in the `(state, y)` form passed to
`pdesolve` by `solve_model`. The HJB is written for `p(w)` and divided by `(1-γ) V / p`.

`y` carries `pR`, its forward/backward first differences `pRw_up`/`pRw_down` and its second
difference `pRww`. Consumption follows from the first-order condition `c = p * p_w^(-1/γ)`.

Returns `((; pRt), (; w, pR, pRw, pRww, μRw, σRw, cR))`: the time derivative of `pR` and
the chosen derivative, policy and law of motion of `w` (drift `μRw`, volatility `σRw`).
"""
function solve_model_old(m, state::NamedTuple, y::NamedTuple)
    (; μR, σR, χR, r, γ, ρ, wmin, wmax) = m
    (; w) = state
    (; pR, pRw_up, pRw_down, pRww) = y
    floorpw = sqrt(eps())

    # Upwinding: use the forward difference; if the drift of w it implies is non-positive,
    # switch (once) to the backward difference.
    pRw = max(pRw_up, floorpw)
    cR = pR * pRw^(-1/γ)
    μRw = (r - μR + σR^2) * w + χR - cR
    if μRw <= 0
        pRw = max(pRw_down, floorpw)
        cR = pR * pRw^(-1/γ)
        μRw = (r - μR + σR^2) * w + χR - cR
    end
    σRw = -w * σR

    # Borrowing limit: consume exactly income so that w cannot drift below wmin.
    if w ≈ wmin && μRw <= 0.0
        μRw = 0.0
        cR = χR
        pRw = (cR / pR)^(-γ)
    end
    # Upper end of the grid: impose zero curvature.
    if w ≈ wmax
        pRww = 0.0
    end

    # Terms of the normalized HJB, summed in the same order as the closed-form expression.
    flow = cR * pRw / (1 - γ)
    growth = (μR - 0.5 * γ * σR^2) * pR
    drift = ((r - μR + γ * σR^2) * w + χR - cR) * pRw
    diffusion = 0.5 * σR^2 * w^2 * (pRww - γ * pRw^2 / pR)
    discount = ρ * pR / (1 - γ)
    pRt = -(flow + growth + drift + diffusion - discount)
    return (; pRt), (; w, pR, pRw, pRww, μRw, σRw, cR)
end

# Young (employed / unemployed) households: solved after the old, from their value at retirement
"""
    solve_model_young(m, state, y)

HJB residuals of employed (E) and unemployed (U) households at one grid point, in the
`(state, y)` form passed to `pdesolve` by `solve_model`. Both states share labor income
dynamics `μ`, `σ`. Flow income is `1` when employed and `χU` when unemployed. The states
are linked by the hazard rates `λEU` (E → U) and `λUE` (U → E).

Returns `((; pEt, pUt), (; pE, pEw, pEww, μEw, σEw, cE, pU, pUw, pUww, μUw, σUw, cU))`.
"""
function solve_model_young(m::IncompleteMarketsModel, state::NamedTuple, y::NamedTuple)
    (; μ, σ, λEU, λUE, χU, r, γ, ρ, wmin, wmax) = m
    (; w) = state
    (; pE, pEw, pEww, pU, pUww) = (pE = y.pE, pEw = nothing, pEww = y.pEww, pU = y.pU, pUww = y.pUww)
    (; pEw_up, pEw_down, pUw_up, pUw_down) = y
    ε = sqrt(eps())

    # Upwind policy for a state with flow `income`. Try the forward difference first and
    # fall back (once) to the backward difference if the implied drift of w is
    # non-positive. At the borrowing limit, consume income so that w does not fall.
    function policy(p, pw_up, pw_down, income)
        pw = max(pw_up, ε)
        c = p * pw^(-1/γ)
        μw = (r - μ + σ^2) * w + income - c
        if μw <= 0
            pw = max(pw_down, ε)
            c = p * pw^(-1/γ)
            μw = (r - μ + σ^2) * w + income - c
        end
        if w ≈ wmin && μw <= 0.0
            μw = 0.0
            c = income
            pw = (c / p)^(-γ)
        end
        return pw, c, μw
    end

    # Normalized HJB residual for a state that switches at intensity λ to the other
    # employment state, whose normalized value is p_other.
    function hjb(p, pw, pww, c, income, λ, p_other)
        flow = c * pw / (1 - γ)
        growth = (μ - 0.5 * γ * σ^2) * p
        drift = ((r - μ + γ * σ^2) * w + income - c) * pw
        diffusion = 0.5 * σ^2 * w^2 * (pww - γ * pw^2 / p)
        jump = λ * p / (1 - γ) * (max(p_other, ε)^(1-γ) / max(p, ε)^(1-γ) - 1)
        discount = ρ * p / (1 - γ)
        return -(flow + growth + drift + diffusion + jump - discount)
    end

    pEw, cE, μEw = policy(pE, pEw_up, pEw_down, 1)
    pUw, cU, μUw = policy(pU, pUw_up, pUw_down, χU)
    σEw = -w * σ
    σUw = -w * σ
    # Upper end of the grid: impose zero curvature in both states.
    if w ≈ wmax
        pEww = 0.0
        pUww = 0.0
    end

    pEt = hjb(pE, pEw, pEww, cE, 1, λEU, pU)
    pUt = hjb(pU, pUw, pUww, cU, χU, λUE, pE)
    return (; pEt, pUt), (; pE, pEw, pEww, μEw, σEw, cE, pU, pUw, pUww, μUw, σUw, cU)
end

"""
    solve_model(m, stategrid, τs_young, τs_old, yend; verbose = true)

Solve the household problem backward in time with `pdesolve`:

1. Retirees on `τs_old`, from the terminal condition `yend` (which must hold `:pR`).
2. Employed/unemployed households on `τs_young`. Their terminal condition is
   `pE = pU = result[:pR][:, 1]`, the retirees' solution at the first index of `τs_old`.

Each stage warns if any of its residual norms exceeds `1e-6`.

Returns an `OrderedDict` mapping every output of `solve_model_old` / `solve_model_young`
(e.g. `:pR`, `:cR`, `:μRw`, `:pE`, `:cU`, ...) to a `length(stategrid[:w]) × length(τs)`
matrix. Column `i` holds `pdesolve`'s `i`-th output (presumably the solution at `τs[i]`).
"""
function solve_model(m::IncompleteMarketsModel, stategrid, τs_young, τs_old,  yend; verbose = true)
    wn = length(stategrid[:w])
    result = OrderedDict()

    # Stack pdesolve's per-date outputs into wn × length(τs) matrices stored in `result`.
    function store!(outputs, τs)
        for key in keys(outputs[1])
            mat = zeros(wn, length(τs))
            for i in eachindex(τs)
                mat[:, i] = outputs[i][key]
            end
            result[key] = mat
        end
        return result
    end

    # solve for old
    _, residual_norm_old, result_old = pdesolve((state, y) -> solve_model_old(m, state, y), stategrid, yend, τs_old; verbose = verbose)
    all(residual_norm_old .<= 1e-6) || @warn "residual norm not zero"
    store!(result_old, τs_old)

    # solve for young, starting from the retirees' value at the start of retirement
    yend_young = OrderedDict(
        :pE => result[:pR][:, 1],
        :pU => result[:pR][:, 1])
    _, residual_norm_young, result_young = pdesolve((state, y) -> solve_model_young(m, state, y), stategrid, yend_young, τs_young; verbose = verbose)
    # check the young solve's own residuals
    all(residual_norm_young .<= 1e-6) || @warn "residual norm not zero"
    store!(result_young, τs_young)
    return result
end

#=============================================================================================

Solve for density

=============================================================================================#

# Solve for the density, starting from an initial distribution of assets ψ
"""
    solve_density(m, stategrid, τs_young, τs_old, result, ψ)

Propagate densities forward in time with implicit Euler steps of the Kolmogorov forward
equation `(I - T' Δτ) g_next = g`. The initial asset distribution `ψ` is split between
employment (share `λUE / (λEU + λUE)`) and unemployment (share `λEU / (λEU + λUE)`).

Two densities are tracked:
- `g*`: mass of households, using the generator of `w` under `result`'s policies.
- `gY*`: income-weighted mass. Its operator adds `μ * I` (`μR * I` when retired) and
  shifts the drift of `w` by `σ_w * σ` (`σ_w * σR` when retired).

Column `it` stores the density *before* the step at time index `it`. Retirees start from
the sum of the last stored E and U columns, so that period is repeated, mirroring the
policies.

Returns `(gE, gU, gR, gYE, gYU, gYR)`, sized like `result[:cE]`, `result[:cU]`, `result[:cR]`.
"""
function solve_density(m::IncompleteMarketsModel, stategrid, τs_young, τs_old, result, ψ)
    (; μ, σ, μR, σR, λEU, λUE) = m
    wgrid = stategrid[:w]
    nw = length(wgrid)
    Δyoung = step(τs_young)
    Δold = step(τs_old)
    Q = [-λEU λEU; λUE -λUE]  # intensity matrix of E ↔ U switches

    gE = zeros(size(result[:cE]))
    gU = zeros(size(result[:cU]))
    gR = zeros(size(result[:cR]))
    gYE = zeros(size(result[:cE]))
    gYU = zeros(size(result[:cU]))
    gYR = zeros(size(result[:cR]))

    # newborns split across E and U according to the stationary employment shares
    shareE = λUE / (λEU + λUE)
    shareU = λEU / (λEU + λUE)
    g = vcat(shareE .* ψ, shareU .* ψ)
    gY = vcat(shareE .* ψ, shareU .* ψ)

    for it in axes(result[:cE], 2)
        gE[:, it] = g[1:nw]
        gU[:, it] = g[(nw + 1):end]
        gYE[:, it] = gY[1:nw]
        gYU[:, it] = gY[(nw + 1):end]

        μE, σE = result[:μEw][:, it], result[:σEw][:, it]
        μU, σU = result[:μUw][:, it], result[:σUw][:, it]

        T = jointoperator([generator(DiffusionProcess(wgrid, μE, σE)),
                           generator(DiffusionProcess(wgrid, μU, σU))], Q)
        g = (I - sparse(T)' * Δyoung) \ g

        TY = jointoperator([μ * I + generator(DiffusionProcess(wgrid, μE .+ σE .* σ, σE)),
                            μ * I + generator(DiffusionProcess(wgrid, μU .+ σU .* σ, σU))], Q)
        gY = (I - sparse(TY)' * Δyoung) \ gY
    end

    # the backward induction for policies repeats the last period, so the densities do too
    g = gE[:, end] .+ gU[:, end]
    gY = gYE[:, end] .+ gYU[:, end]
    for it in axes(result[:cR], 2)
        gR[:, it] = g
        gYR[:, it] = gY

        μRt, σRt = result[:μRw][:, it], result[:σRw][:, it]

        TR = generator(DiffusionProcess(wgrid, μRt, σRt))
        g = (I - TR' * Δold) \ g

        TYR = μR * I + generator(DiffusionProcess(wgrid, μRt .+ σRt .* σR, σRt))
        gY = (I - TYR' * Δold) \ gY
    end
    return gE, gU, gR, gYE, gYU, gYR
end

# Solve for the fixed point of the density such that the average assets at death equal the average assets of newborns
"""
    solve_density(m, stategrid, τs_young, τs_old, result; tol = 1e-2, verbose = true, maxiter = 100)

Find newborn asset densities consistent with bequests. Newborns' assets are lognormal
(log-standard deviation 0.5) with mean `A`. `A` is iterated until it matches the average
assets in the last retirement period, computed from the income-weighted density `gYR`,
to within `tol`.

Returns `(gE, gU, gR, gYE, gYU, gYR)` from
`solve_density(m, stategrid, τs_young, τs_old, result, ψ)` at the converged `ψ`.
If the fixed point is not reached within `maxiter` iterations, a warning is emitted and
the last iterate is returned.

Throws `ArgumentError` if `maxiter < 1`.
"""
function solve_density(m::IncompleteMarketsModel, stategrid, τs_young, τs_old, result; tol = 1e-2, verbose = true, maxiter = 100)
    maxiter >= 1 || throw(ArgumentError("maxiter must be at least 1, got $(maxiter)"))
    w = stategrid[:w]
    σψ = 0.5  # standard deviation of log assets of newborns
    # Initial guess: people have approximately 0.7 of average wealth,
    # and average wealth is about 6 times labor income.
    oldAend = 0.7 * 6
    densities = nothing
    for iter in 1:maxiter
        verbose && println("Try $(iter): guess asset bequeathed is $(oldAend)")
        # Lognormal CDF with mean oldAend (log-mean shifted by -σψ²/2), evaluated on the grid.
        # First differences give the mass in each grid cell (zero at the first point).
        Ψ = 0.5 .* (1 .+ erf.((log.(w) .- (log(oldAend) - σψ^2 / 2)) ./ (sqrt(2) * σψ)))
        ψ = [Ψ[k] - Ψ[max(1, k - 1)] for k in eachindex(Ψ)]
        ψ .= ψ ./ sum(ψ)
        densities = solve_density(m, stategrid, τs_young, τs_old, result, ψ)
        gYR = densities[6]
        # average assets (relative to initial income) in the last retirement period
        Aend = w' * gYR[:, end]
        abs(Aend - oldAend) <= tol && return densities
        oldAend = Aend
    end
    @warn "solve_density: bequest fixed point not reached within $(maxiter) iterations (tol = $(tol)); returning last iterate"
    return densities
end

#=============================================================================================

Compute intermediate values used in calibration

=============================================================================================#

# actual labor income (not permanent one) (relative to initial labor income)
"""
    compute_laborincome(m, stategrid, τs_young, τs_old, gYE, gYU, gYR)

Return average actual (not permanent) labor income over the life cycle, relative to
initial labor income.

Each `gY*` is an income-weighted density with one column per time index. Its column sum
is therefore the permanent income held by households in that state. Actual income is `Y`
when employed, `χU * Y` when unemployed and `χR * Y` when retired. Young (E+U) and retired
periods are weighted by their share of time indices.
"""
function compute_laborincome(m::IncompleteMarketsModel,  stategrid, τs_young, τs_old, gYE, gYU, gYR)
    nR = size(gYR, 2)
    nEU = size(gYE, 2)
    πR = nR / (nEU + nR)
    πEU = nEU / (nEU + nR)
    massR = mean(sum(g) for g in eachcol(gYR))
    massE = mean(sum(g) for g in eachcol(gYE))
    massU = mean(sum(g) for g in eachcol(gYU))
    # unemployed receive only the fraction χU of permanent income, as retirees receive χR
    return πR * m.χR * massR + πEU * (massE + m.χU * massU)
end

"""
    compute_wealth(m, stategrid, τs_young, τs_old, gYE, gYU, gYR)

Return average wealth over the life cycle, relative to initial labor income.

For each income-weighted density column, `dot(column, w)` is the wealth held in that state
at that date. Dates are averaged within each stage. Young (E+U) and retired stages are
weighted by their share of time indices.
"""
function compute_wealth(m::IncompleteMarketsModel,  stategrid, τs_young, τs_old, gYE, gYU, gYR)
    grid = stategrid[:w]
    nold, nyoung = size(gYR, 2), size(gYE, 2)
    ntotal = nyoung + nold
    avgassets(G) = mean(dot(col, grid) for col in eachcol(G))
    return (nold / ntotal) * avgassets(gYR) + (nyoung / ntotal) * (avgassets(gYE) + avgassets(gYU))
end

"""
    compute_moments(m, ρ, b; dt = 0.25, wn = 1000, tol = 1e-2)

Solve the model with discount rate `ρ` and bequest weight `b` on a copy of `m`, so `m`
itself is not modified. The asset grid has `wn` points spaced uniformly in `w^(1/3)`. The
old are solved on ages 65–85 and the young on ages 20–65, both with step `dt`. The bequest
fixed point is solved with tolerance `tol`.

Returns `[wealth / laborincome, wealth_old / wealth]`, where `wealth_old` is the average
wealth in the last retirement period.
"""
function compute_moments(m::IncompleteMarketsModel, ρ, b; dt = 0.25, wn = 1000, tol = 1e-2)
    model = deepcopy(m)
    model.ρ = ρ
    model.b = b

    wgrid = range(model.wmin^(1/3), model.wmax^(1/3), length = wn).^3
    stategrid = OrderedDict(:w => wgrid)
    τs_old = range(65, 85, step = dt)
    τs_young = range(20, first(τs_old), step = dt)

    # terminal (bequest) condition: pR linear in w, shifted by 0.01 to stay positive at w = 0
    slope = model.b^(1/(1-model.γ)) .* (model.r + (model.ρ - model.r) / model.γ)^(1/(1 - 1 / model.γ))
    yend = OrderedDict(:pR => slope * (wgrid .+ 0.01))

    result = solve_model(model, stategrid, τs_young, τs_old, yend; verbose = false)
    gE, gU, gR, gYE, gYU, gYR = solve_density(model, stategrid, τs_young, τs_old, result; tol = tol, verbose = false)

    laborincome = compute_laborincome(model, stategrid, τs_young, τs_old, gYE, gYU, gYR)
    wealth = compute_wealth(model, stategrid, τs_young, τs_old, gYE, gYU, gYR)
    wealth_old = dot(gYR[:, end], wgrid)
    return [wealth / laborincome, wealth_old / wealth]
end

"""
    calibrate(m; wealth_income_ratio = 6, oldwealth_ratio = 0.7, dt = 0.25, wn = 1000, tol = 1e-2)

Choose `ρ` and `b` so that `compute_moments` matches the targets
`[wealth_income_ratio, oldwealth_ratio]`. The objective is the norm of relative deviations.
The search starts from `[m.ρ, m.b]` and uses Optim's default method without bounds. `b` is
floored at 0.01 inside the objective, so the returned `b` uses the same floor.

Returns a copy of `m` with the calibrated `ρ` and `b`; `m` itself is not modified.
"""
function calibrate(m; wealth_income_ratio = 6, oldwealth_ratio = 0.7, dt = 0.25, wn = 1000, tol = 1e-2)
    targets = [wealth_income_ratio, oldwealth_ratio]
    bfloor(x) = max(x, 0.01)
    objective(x) = norm(compute_moments(m, x[1], bfloor(x[2]), dt = dt, wn = wn, tol = tol) ./ targets .- 1)
    res = optimize(objective, [m.ρ, m.b])
    xmin = Optim.minimizer(res)

    # apply the calibrated parameters (previously they were computed and discarded)
    calibrated = deepcopy(m)
    calibrated.ρ = xmin[1]
    calibrated.b = bfloor(xmin[2])
    @info "calibrate" ρ = calibrated.ρ b = calibrated.b
    return calibrated
end

#=============================================================================================

Solve for welfare gains

=============================================================================================#

"""
    solve_welfare(m, stategrid, τs_young, τs_old, result; ϕ = 0.05)

Compute, for every asset grid point and time index of `result`, two backward-recursion
objects named `wg` (presumably "welfare gain", per the section header — TODO confirm).

- `*_baseline`: implicit backward steps of `(I + ((r + ϕ) I - T) Δτ) x = x_next + s Δτ`.
  The source is `s = -(r * w + income - c)`, with income `χR`, `1` (E) or `χU` (U). `T`
  is `μ * I` (`μR * I` when retired) plus the generator of `w` with drift `μ_w + σ_w * σ`.
- `*_stochastic`: the same source weighted by marginal utility `c^(-γ)`. `ϕ * I - T` is
  used with `T = -ρ I + ((1-γ)μ - ½γ(1-γ)σ²) I + generator` (drift `μ_w + σ_w (1-γ) σ`).
  Stored values are divided back by `c^(-γ)`.

Retirees start from a zero terminal value. The young (E and U, coupled through the E↔U
intensity matrix) start from the retirees' value stored at the first retirement index,
stacked for E and U. In every loop a column is stored before that index's own step.

Returns `(wgR_baseline, wgE_baseline, wgU_baseline, wgR_stochastic, wgE_stochastic,
wgU_stochastic)`, each sized like the corresponding `result[:c*]` matrix.
"""
function solve_welfare(m::IncompleteMarketsModel, stategrid, τs_young, τs_old, result; ϕ = 0.05)
    (; μ, σ, μR, σR, χR, λEU, λUE, χU, r, γ, ρ, b, wmin, wmax) = m

    # Retirees, baseline: backward in time from a zero terminal value, discounting at r + ϕ
    wgR_baseline = zeros(size(result[:cR]))
    wn = length(stategrid[:w])
    curr = zeros(wn)
    for it in size(result[:cR], 2):(-1):1
        wgR_baseline[:, it] = curr
        TR =  μR * I + generator(DiffusionProcess(stategrid[:w], result[:μRw][:, it] .+ result[:σRw][:, it] .* σR, result[:σRw][:, it]))
        # source term: minus (interest income + retirement income - consumption)
        curr = (I + ((r + ϕ) * I - TR) .* step(τs_old)) \ (curr .+ (.- (r .* stategrid[:w] .+ χR .- result[:cR][:, it])) .* step(τs_old))
    end

    # Young (E stacked over U), baseline: terminal value is the retirees' first stored column
    wgE_baseline = zeros(size(result[:cE]))
    wgU_baseline = zeros(size(result[:cU]))
    curr = vcat(wgR_baseline[:, 1], wgR_baseline[:, 1])
    for it in size(result[:cE], 2):(-1):1
        wgE_baseline[:, it] = curr[1:wn]
        wgU_baseline[:, it] = curr[(wn+1):end]
        TE =  μ * I + generator(DiffusionProcess(stategrid[:w], result[:μEw][:, it] .+ result[:σEw][:, it] .* σ, result[:σEw][:, it]))
        TU =  μ * I + generator(DiffusionProcess(stategrid[:w], result[:μUw][:, it] .+ result[:σUw][:, it] .* σ, result[:σUw][:, it]))
        # couple E and U through the switching intensities
        T = jointoperator([TE, TU], [-λEU λEU; λUE -λUE])       
        curr = (I + ((r + ϕ) * I - sparse(T)) .* step(τs_young)) \ (curr .+ (.- vcat(r .* stategrid[:w] .+ 1 .- result[:cE][:, it], r .* stategrid[:w] .+ χU .- result[:cU][:, it])) .* step(τs_young))
    end

    # Retirees, stochastic: recursion on the marginal-utility-weighted value c^(-γ) * wg,
    # discounting at ρ + ϕ with the growth adjustment of Y^(1-γ); stored divided by c^(-γ)
    wgR_stochastic = zeros(size(result[:cR]))
    curr = zeros(wn)
    for it in size(result[:cR], 2):(-1):1
        wgR_stochastic[:, it] = curr ./ result[:cR][:, it].^(-γ)
        TR =  - ρ * I + ((1 - γ) * μR + 0.5 * (1 - γ) * (-γ) * σR^2) * I + generator(DiffusionProcess(stategrid[:w], result[:μRw][:, it] .+ result[:σRw][:, it] .* (1 - γ) * σR, result[:σRw][:, it]))
        curr = (I + (ϕ * I - TR) .* step(τs_old)) \ (curr .+ result[:cR][:, it].^(-γ) .* (.- (r .* stategrid[:w] .+ χR .- result[:cR][:, it])) .* step(τs_old))
    end

    # Young, stochastic: terminal value re-weights the retirees' first stored column by cR^(-γ)
    wgE_stochastic = zeros(size(result[:cE]))
    wgU_stochastic = zeros(size(result[:cU]))
    curr = vcat(result[:cR][:, 1].^(-γ) .* wgR_stochastic[:, 1], result[:cR][:, 1].^(-γ) .* wgR_stochastic[:, 1])
    for it in size(result[:cE], 2):(-1):1
        wgE_stochastic[:, it] = curr[1:wn] ./ result[:cE][:, it].^(-γ)
        wgU_stochastic[:, it] = curr[(wn+1):end] ./ result[:cU][:, it].^(-γ)
        TE =  - ρ * I + ((1 - γ) * μ + 0.5 * (1 - γ) * (-γ) * σ^2) * I + generator(DiffusionProcess(stategrid[:w], result[:μEw][:, it] .+ result[:σEw][:, it] .* (1 - γ) * σ, result[:σEw][:, it]))
        TU =  - ρ * I + ((1 - γ) * μ + 0.5 * (1 - γ) * (-γ) * σ^2) * I + generator(DiffusionProcess(stategrid[:w], result[:μUw][:, it] .+ result[:σUw][:, it] .* (1 - γ) * σ, result[:σUw][:, it]))
        T = jointoperator([TE, TU], [-λEU λEU; λUE -λUE])       
        curr = (I + (ϕ * I - sparse(T)) .* step(τs_young)) \ (curr .+ vcat(result[:cE][:, it].^(-γ) .* (.- (r .* stategrid[:w] .+ 1 .- result[:cE][:, it])), result[:cU][:, it].^(-γ) .* (.- (r .* stategrid[:w] .+ χU .- result[:cU][:, it]))) .* step(τs_young))
    end
    return wgR_baseline, wgE_baseline, wgU_baseline, wgR_stochastic, wgE_stochastic, wgU_stochastic
end

#=============================================================================================

Explore other calibrations

=============================================================================================#

"""
    solve_experiment(; χU = 0.6, σ = 0.1, ρ = 0.055, γ = 2, ϕ = 0.05)

Solve the model for the given parameters, with the others at their defaults. A cohort
starts with a point mass of assets at the first grid point with `w ≥ 3`. The function
computes `solve_welfare` and summarizes baseline vs stochastic `wg` over the life cycle.

Life-cycle averages weight three stages by their share of time indices:
- the young (E+U);
- the retired;
- an extra stage `B` over `τ ∈ [0, 20]`, before the young stage starts at 20.

The stage-`B` values are the first young period's population-weighted `wg`, discounted
back at `r + ϕ`. Within a stage, each date contributes `∑_w f(wg) * gY`.

Returns `(wg_baseline_meanabs, wg_stochastic_meanabs, wg_rmse, wg_corr)`. `wg_corr` is an
uncentered correlation: the second moments are not demeaned.
"""
function solve_experiment(;χU = 0.6, σ = 0.1, ρ = 0.055, γ = 2, ϕ = 0.05)
    m = IncompleteMarketsModel(ρ = ρ, γ = γ, χU = χU, σ = σ)
    stategrid = OrderedDict(:w => range(m.wmin^(1/3), m.wmax^(1/3), length = 1000).^3)
    τs_old = range(65, 85, step = 0.25)
    τs_young = range(20, first(τs_old), step = 0.25)
    yend = OrderedDict(
        :pR => m.b^(1/(1-m.γ)) .* (m.r + (m.ρ - m.r) / m.γ)^(1/(1 - 1 / m.γ)) * (stategrid[:w] .+ 0.01)
        )
    result = solve_model(m, stategrid, τs_young,  τs_old, yend; verbose = false)

    # cohort starts with a point mass of assets at w0
    w0 = 3.0
    ψ = zeros(length(stategrid[:w]))
    ψ[searchsortedfirst(stategrid[:w], w0)] = 1.0
    _, _, _, gYE, gYU, gYR = solve_density(m, stategrid, τs_young, τs_old, result, ψ)
    wgR_b, wgE_b, wgU_b, wgR_s, wgE_s, wgU_s = solve_welfare(m, stategrid, τs_young, τs_old, result; ϕ = ϕ)

    # stage B: first young period's population-weighted wg, discounted back at r + ϕ
    τB = range(0, 20, step = 0.25)
    wgEU_b = sum(wgE_b .* gYE, dims = 1) + sum(wgU_b .* gYU, dims = 1)
    wgEU_s = sum(wgE_s .* gYE, dims = 1) + sum(wgU_s .* gYU, dims = 1)
    wgB_b = exp.((m.r + ϕ) .* (τB .- τB[end]))' .* wgEU_b[1]
    wgB_s = exp.((m.r + ϕ) .* (τB .- τB[end]))' .* wgEU_s[1]

    nEU, nR, nB = size(result[:pE], 2), size(result[:pR], 2), length(τB)
    ntot = nEU + nR + nB
    πR, πEU, πB = nR / ntot, nEU / ntot, nB / ntot

    # life-cycle average of per-state statistics fR, fE, fU (grid × time) and fB (stage B)
    lifecycle_mean(fR, fE, fU, fB) =
        πR * mean(sum(fR .* gYR, dims = 1)) + πEU * mean(sum(fE .* gYE, dims = 1)) +
        πEU * mean(sum(fU .* gYU, dims = 1)) + πB * mean(fB)

    # abs is applied to every stage, including B, for both measures
    wg_baseline_meanabs = lifecycle_mean(abs.(wgR_b), abs.(wgE_b), abs.(wgU_b), abs.(wgB_b))
    wg_stochastic_meanabs = lifecycle_mean(abs.(wgR_s), abs.(wgE_s), abs.(wgU_s), abs.(wgB_s))
    wg_rmse = sqrt(lifecycle_mean((wgR_s .- wgR_b).^2, (wgE_s .- wgE_b).^2, (wgU_s .- wgU_b).^2, (wgB_s .- wgB_b).^2))

    wg_cov = lifecycle_mean(wgR_s .* wgR_b, wgE_s .* wgE_b, wgU_s .* wgU_b, wgB_s .* wgB_b)
    wg_var1 = lifecycle_mean(wgR_s.^2, wgE_s.^2, wgU_s.^2, wgB_s.^2)
    wg_var2 = lifecycle_mean(wgR_b.^2, wgE_b.^2, wgU_b.^2, wgB_b.^2)
    wg_corr = wg_cov / sqrt(wg_var1 * wg_var2)

    return wg_baseline_meanabs, wg_stochastic_meanabs, wg_rmse, wg_corr
end